import anndata as ad
import hdf5plugin
import scipy
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
import time
import numpy as np
import pandas as pd
import math
import scanpy as sc
import wandb
import argparse
from tqdm import tqdm
from copy import deepcopy
from CellBert.utils.eval import minimum_eval, clustering_eval
from CellBert.utils.data import XDict, clean_batches, balanced_partition, data_setup
from CellBert.model import OmicsFormer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.multiprocessing import Process, Manager
import pickle
import os
import json
# wandb.login()

def corr(y_true, y_pred):
    """Return the mean Pearson correlation between ``y_true`` and ``y_pred`` along dim 1.

    Each slice along dim 1 is centred on its own mean, the per-slice
    correlations are computed, and their average is returned as a scalar tensor.
    Constant slices yield NaN (division by a zero norm), as before.
    """
    true_dev = y_true - y_true.mean(dim=1, keepdim=True)
    pred_dev = y_pred - y_pred.mean(dim=1, keepdim=True)
    covariance = (true_dev * pred_dev).sum(dim=1)
    true_norm = (true_dev * true_dev).sum(dim=1).sqrt()
    pred_norm = (pred_dev * pred_dev).sum(dim=1).sqrt()
    return (covariance / true_norm / pred_norm).mean()

def setup(rank, world_size):
    """Join the single-node NCCL default process group as ``rank`` of ``world_size``.

    The rendezvous address is fixed at localhost:12355, and tensors shared
    between processes use the 'file_system' sharing strategy.
    """
    os.environ.update(MASTER_ADDR='localhost', MASTER_PORT='12355')
    mp.set_sharing_strategy('file_system')
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """Destroy the default process group if one is active.

    Safe to call more than once, or when ``setup`` never ran or failed. Calling
    ``dist.destroy_process_group()`` on an uninitialised group would raise.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

def _mean(values):
    """Return the arithmetic mean of ``values``, or NaN for an empty sequence."""
    return sum(values) / len(values) if values else float('nan')


def _build_x_dict(i, seq_list, batch_list, coord_list, gene_dict, dataset_list, device):
    """Assemble the model input ``XDict`` for batch ``i`` on CUDA device ``device``.

    ``seq_list`` holds four parallel per-batch lists. Their entries are passed
    positionally to ``torch.sparse_csr_tensor`` (crow_indices, col_indices,
    values, size), and the CSR tensor is converted to a coalesced float COO
    tensor.
    """
    x = torch.sparse_csr_tensor(seq_list[0][i], seq_list[1][i], seq_list[2][i],
                                seq_list[3][i].tolist()).to_sparse().float().coalesce()
    return XDict({'x_seq': x.cuda(device),
                  'batch': batch_list[i].cuda(device),
                  'coord': coord_list[i].cuda(device),
                  'gene_mask': gene_dict[dataset_list[i]].cuda(device)})


def _reconstruction_pass(net, indices, make_input):
    """Run ``net`` without gradients over ``indices``; collect Pearson scores and latents.

    Best-effort: a batch whose forward pass raises is reported and skipped.
    Its index is left out of ``kept``, so callers can align per-batch metadata
    with the returned latents.

    Returns:
        ``(pearsons, latents, kept)``: per-batch reconstruction Pearson, CPU
        latent tensors and the indices of the batches that succeeded, all in
        the same order.
    """
    pearsons, latents, kept = [], [], []
    with torch.no_grad():
        for i in indices:
            try:
                x_dict = make_input(i)
                y = x_dict['x_seq'].to_dense()
                out_dict, _ = net(x_dict)
                pearson = corr(out_dict['recon'], y).cpu().item()
                latent = out_dict['latent'].cpu()
            except Exception as err:  # keep evaluating the other batches, but say why
                print(f'Skipping batch {i}: {err!r}')
                continue
            pearsons.append(pearson)
            latents.append(latent)
            kept.append(i)
    return pearsons, latents, kept


def _embedding_anndata(latents, kept, batch_list, label_list):
    """Stack latents into an AnnData whose obs rows match the latent rows.

    Only the batches in ``kept`` contribute ``batch`` / ``cell_type`` metadata,
    so obs stays aligned with X even if some batches were skipped.
    """
    res = torch.cat(latents, dim=0).numpy()
    batch_codes = torch.cat([batch_list[i] for i in kept]).numpy()
    cell_types = torch.cat([label_list[i] for i in kept]).numpy().tolist()
    data_eval = ad.AnnData(X=res, obs=pd.DataFrame({'batch': batch_codes}))
    data_eval.obs['cell_type'] = cell_types
    data_eval.obs['cell_type'] = data_eval.obs['cell_type'].astype('category')
    data_eval.obs['batch'] = data_eval.obs['batch'].astype('category')
    data_eval.obsm['X_cellbert'] = res
    return data_eval


def train(rank, world_size, config, shared_data, tune_flag, dataset_list, gene_dict, partitions, val_num):
    """Train OmicsFormer on one process/GPU; on rank 0 also save and evaluate it.

    Args:
        rank: process index, also used as the CUDA device index.
        world_size: total number of processes in the group.
        config: passed to ``OmicsFormer(**config)``; also read for ``lr``,
            ``wd``, ``epochs`` and ``name``.
        shared_data: ``[batch_list, seq_list, order_list, coord_list, label_list]``,
            one entry per batch (``seq_list`` is four parallel CSR component lists).
        tune_flag: log metrics to Weights & Biases when true.
        dataset_list: per-batch key used to look up ``gene_dict``.
        gene_dict: dataset key -> gene mask tensor.
        partitions: per-rank lists of training batch indices.
        val_num: the last ``val_num`` batches are held out. Rank 0 uses them for
            best-checkpoint selection and early stopping.

    Side effects: writes ``<config['name']>.pt`` (DDP ``state_dict``, keys
    prefixed with ``module.``) and prints metrics.
    NOTE(review): ``wandb.init`` happens in ``main``. When this runs inside an
    ``mp.spawn`` worker, confirm ``wandb.log`` reaches that run.
    """
    batch_list, seq_list, order_list, coord_list, label_list = shared_data
    setup(rank, world_size)
    partition = partitions[rank]
    model = OmicsFormer(**config).cuda(rank)
    if rank == 0:
        # Counts every state_dict entry (parameters and buffers).
        print(sum(torch.numel(t) for t in model.state_dict().values()))
    model = DDP(model, device_ids=[rank], find_unused_parameters=True)
    optim = torch.optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=config['wd'])

    def make_input(i):
        return _build_x_dict(i, seq_list, batch_list, coord_list, gene_dict, dataset_list, rank)

    warmup_epochs = 5
    val_indices = range(len(batch_list) - val_num, len(batch_list))
    train_loss, valid_loss = [], []
    best_model_weights = None
    for epoch in range(config['epochs']):
        start_time = time.time()
        # Linear warm-up, set *before* the epoch runs: epochs 0..4 train at
        # lr/5, 2lr/5, ..., lr.
        if epoch < warmup_epochs:
            for param_group in optim.param_groups:
                param_group['lr'] = config['lr'] * (epoch + 1) / warmup_epochs

        model.train()
        epoch_loss = []
        for i in partition:
            x_dict = make_input(i)
            _, loss = model(x_dict)
            optim.zero_grad()
            loss.backward()
            optim.step()
            epoch_loss.append(loss.item())
            del loss, x_dict
        train_loss.append(math.sqrt(_mean(epoch_loss)))

        if epoch > 80:
            # Slow exponential decay once past epoch 80.
            for param_group in optim.param_groups:
                param_group['lr'] *= 0.996

        stop = False
        if rank == 0 and val_num > 0:
            # Only rank 0 validates, so bypass the DDP wrapper: its forward may
            # issue collectives that the other ranks would never join.
            net = model.module
            net.eval()
            epoch_loss = []
            with torch.no_grad():
                for i in val_indices:
                    _, loss = net(make_input(i))
                    epoch_loss.append(loss.item())
            del loss
            torch.cuda.empty_cache()
            valid_loss.append(math.sqrt(_mean(epoch_loss)))
            print(f'Epoch {epoch} | Train loss: {train_loss[-1]:.4f} | Valid loss: {valid_loss[-1]:.4f} | Time: {time.time() - start_time:.2f}')
            if tune_flag:
                wandb.log({"train": train_loss[-1], "valid": valid_loss[-1]})

            if min(valid_loss) == valid_loss[-1]:
                best_model_weights = deepcopy(model.state_dict())
            # Stop when none of the last 50 epochs reached the best loss so far.
            if epoch > 0 and min(valid_loss[-50:]) != min(valid_loss):
                print('Early stopped.')
                stop = True

        # All ranks must agree on stopping. Otherwise the ranks that keep
        # training block forever in DDP's gradient all-reduce.
        stop_flag = torch.tensor([int(stop)], device=f'cuda:{rank}')
        dist.broadcast(stop_flag, src=0)
        if stop_flag.item():
            break

    if rank != 0:
        # Only rank 0 saves and evaluates; release this worker's group now.
        cleanup()
        return

    checkpoint = f'{config["name"]}.pt'
    if best_model_weights is not None:
        torch.save(best_model_weights, checkpoint)
        model.load_state_dict(best_model_weights)
        del best_model_weights
    else:
        # No validation split (or no best epoch recorded): keep final weights.
        torch.save(model.state_dict(), checkpoint)

    # Inference over every batch with the unwrapped module (rank 0 only).
    net = model.module
    net.eval()
    pearsons, latents, kept = _reconstruction_pass(net, range(len(batch_list)), make_input)
    del model, net
    torch.cuda.empty_cache()

    if val_num > 0:
        first_val = len(batch_list) - val_num
        print('Validation Pearson:', _mean([p for p, i in zip(pearsons, kept) if i >= first_val]))
    test_pearson = _mean(pearsons)
    print('All Pearson:', test_pearson)

    df = None
    if latents:
        data_eval = _embedding_anndata(latents, kept, batch_list, label_list)
        del latents
        df = minimum_eval(data_eval)
        print(df)
        print(clustering_eval(data_eval))
        del data_eval
    else:
        print('No batch could be embedded; skipping embedding evaluation.')

    if tune_flag:
        wandb.log({'test_pearson': test_pearson})
        if df is not None:
            wandb.log({'graph_conn': df.T['graph_conn'].values[0]})
        wandb.finish()
    torch.cuda.empty_cache()
    cleanup()

def main(config=None):
    """Prepare shared data and launch training (multi-GPU DDP or single process).

    Reads the dataset globals set up by the ``__main__`` section
    (``gene_list``, ``batch_labels``, ``batch_list``, ``seq_list``,
    ``order_list``, ``coord_list``, ``label_list``, ``dataset_list``,
    ``gene_dict``) and ``tune_flag``. When ``tune_flag`` is true, the run config
    comes from ``wandb.config`` instead of ``config``.

    Side effects:
        - adds ``batch_num`` and ``gene_list`` to the config;
        - pickles the config to ``<name>.config.pkl``;
        - moves the per-batch tensors into shared memory;
        - sets the module-level ``val_num``.
    """
    # Only ``val_num`` is assigned here. The other module globals are only read
    # (item assignment on the lists below does not need ``global``).
    global val_num
    if tune_flag:
        wandb.init(
            # set the wandb project where this run will be logged
            group="gmvae-v1",
        )
        config = wandb.config
    config["batch_num"] = batch_labels.max() + 1
    config['gene_list'] = gene_list
    with open(config['name'] + '.config.pkl', 'wb') as f:
        pickle.dump(config, f)
    val_num = config['val_num']

    # The last ``val_num`` batches are held out for validation. Use an explicit
    # end index: ``batch_list[:-val_num]`` would be empty when val_num == 0.
    train_batches = batch_list[:len(batch_list) - val_num]
    if config['ddp']:
        world_size = torch.cuda.device_count()
    else:
        world_size = 1
    partitions = balanced_partition(train_batches, world_size)

    # Move the per-batch tensors into shared memory before handing them to workers.
    for i in range(len(batch_list)):
        batch_list[i] = batch_list[i].share_memory_()
        order_list[i] = order_list[i].share_memory_()
        coord_list[i] = coord_list[i].share_memory_()
        label_list[i] = label_list[i].share_memory_()
    for i in range(4):
        for j in range(len(seq_list[i])):
            seq_list[i][j] = seq_list[i][j].share_memory_()
    shared_data = [batch_list, seq_list, order_list, coord_list, label_list]

    train_args = (world_size, config, shared_data, tune_flag, dataset_list, gene_dict, partitions, val_num)
    if config['ddp']:
        mp.spawn(train, args=train_args, nprocs=world_size, join=True)
    else:
        train(0, *train_args)

if __name__ == '__main__':
    mp.set_sharing_strategy('file_system')
    mp.set_start_method('spawn', force=True)
    parser = argparse.ArgumentParser()
    parser.add_argument("--tune", action='store_true')
    args = parser.parse_args()
    dataset_name = 'CellBert_v1'  # alternatives: 'HLCA_zstd', 'CellBert_subset'

    if dataset_name in ['CellBert_v0', 'CellBert_subset']:  # datasets that haven't been preprocessed
        data = ad.read_h5ad(f'/home/ec2-user/Project/new_integration/{dataset_name}.h5ad')
        gene_list = data.var.index.to_list()
        gene_to_idx = {gene: i for i, gene in enumerate(gene_list)}
        with open(f'{dataset_name}.gene.json', encoding='utf-8') as f:
            gene_dict = json.load(f)
            # Turn each dataset's gene-name list into a boolean mask over gene_list.
            new_gene_dict = {}
            for k in gene_dict.keys():
                gene_mask = torch.zeros(len(gene_list)).int()
                gene_mask[[gene_to_idx[gene] for gene in gene_dict[k]]] = 1
                new_gene_dict[k] = gene_mask.bool()
            gene_dict = new_gene_dict
        data.obs['batch'] = data.obs['batch_label']
        data = clean_batches(data)
        sc.pp.normalize_total(data, target_sum=1e4)
        sc.pp.log1p(data)
        print(data.shape)
    elif dataset_name == 'HLCA_zstd':
        data = ad.read_h5ad('HLCA_zstd.h5ad')
        gene_list = data.var.index.to_list()
        data.obs['platform'] = 'scRNA-seq'
        data.obs['Dataset'] = data.obs['dataset_name']
        # Every dataset measures every gene: one shared all-True mask.
        gene_dict = dict(zip(data.obs['Dataset'].unique(), [torch.ones(len(gene_list)).bool()] * data.obs['Dataset'].nunique()))
        print(data.shape)
        data = clean_batches(data)
        sc.pp.normalize_total(data, target_sum=1e4)
        sc.pp.log1p(data)
    else:  # preprocessed datasets
        data = ad.read_h5ad(f'{dataset_name}.h5ad')  # backed='r' works but is too slow
        with open(f'{dataset_name}.gene.pkl', 'rb') as f:
            gene_dict = pickle.load(f)
        gene_list = data.var.index.to_list()
        print(data.shape)

    seq_list, batch_list, batch_labels, order_list, dataset_list, coord_list, label_list = data_setup(data)
    out_dim = len(gene_list)
    del data

    if args.tune:
        tune_flag = True
        # Every sweep parameter must be a dict (``value`` / ``values`` / ...).
        # ``main``/``train`` also read ``name`` and ``val_num`` from the config.
        # NOTE(review): a fixed ``name`` means every sweep run writes the same
        # checkpoint file; confirm that is acceptable.
        sweep_configuration = {
            'method': 'bayes',
            'name': 'tuning-gmvae',
            'metric': {
                'goal': 'maximize',
                'name': 'graph_conn'
            },
            'parameters': {
                "name": {'value': 'tuning-gmvae'},
                "enc_mod": {'values': ['performer']},
                "enc_hid": {'values': [256]},
                "enc_layers": {'values': [2, 3]},
                "post_latent_dim": {'values': [64, 32, 16]},
                "dec_mod": {'values': ['mlp']},
                "dec_hid": {'values': [128, 64, 256]},
                "dec_layers": {'values': [2]},
                "model_dropout": {'values': [0.1, 0.2, 0.3]},
                "mask_node_rate": {'values': [0.1, 0.3, 0.5, 0.7]},
                "mask_feature_rate": {'values': [0.1, 0.3, 0.5, 0.7]},
                "drop_node_rate": {'values': [0., 0.2, 0.4]},
                "dataset": {'values': ["HLCA"]},
                "architecture": {'values': ["OmicsFormer"]},
                "epochs": {'values': [500]},
                "lr": {'values': [2e-4]},
                "wd": {'values': [0, 1e-8]},
                "gumbel_softmax": {'values': [False]},
                "num_clusters": {'values': [2, 4, 8, 16]},
                "w_li": {'values': [0, 1e-9, 1e-7, 1e-5]},
                "w_en": {'values': [0, 1e-6, 1e-5, 1e-4]},
                "w_ce": {'values': [0, 1e-6, 1e-5, 1e-4]},
                "out_dim": {'values': [out_dim]},
                "ddp": {'value': True},
                "val_num": {'value': 2},
            }
        }
        sweep_id = wandb.sweep(sweep=sweep_configuration, project='CellBert')
        print(sweep_id)
        # NOTE(review): this runs the pre-existing sweep "s92rhp0o", not the
        # sweep_id just created above. Pass sweep_id to run the new configuration.
        wandb.agent(sweep_id="s92rhp0o", function=main, count=200)
    else:
        tune_flag = False
        config = {
            "name": "20230515_5M_12M",
            "enc_mod": 'performer',
            "enc_hid": 192,
            "enc_layers": 2,
            "latent_mod": 'gmvae',
            "post_latent_dim": 64,
            "dec_mod": 'mlp',
            "dec_hid": 128,
            "dec_layers": 2,
            "model_dropout": 0.1,
            "mask_node_rate": 0.5,
            "mask_feature_rate": 0.7,
            "drop_node_rate": 0.2,
            "dataset": "HLCA",
            "architecture": "OmicsFormer",
            "epochs": 500,
            "lr": 2e-4,
            "wd": 1e-8,
            "w_li": 1e-9,
            "w_en": 1e-4,
            "w_ce": 1e-4,
            "gumbel_softmax": False,
            "num_clusters": 16,
            "out_dim": out_dim,
            "ddp": True,
            'val_num': 2,
        }
        main(config)


